import copy

import torch
import numpy as np
import matplotlib.pyplot as plt
from TrajHeatMap import TrajHM
from FuncHeatMap import FuncHM
from GradVarHeatMap import GradVarHM
from GradNormHeatMap import GradNormHM
import random
from Visualization import Visualization, Compute, funcaverage
from LossFunctions import ConstructLossFunctions, CritPoints
from GD import GD
from SGD import SGD
from NoisyGD import NoisyGD
from Evaluate import Evaluate
import time
import os

# ---- Experiment configuration (read by the __main__ driver below) ----
# NumFunc        : forwarded to ConstructLossFunctions(NumFunc=...).
# NumTrainingSet : how many training sets are resampled; also the length of
#                  every per-optimizer result tensor.
NumFunc, NumTrainingSet = 30, 4
# RandomSeed : seeds both `random` and `torch` at start-up.
# NumRun     : optimizer runs per initial point; the run index is passed as
#              `seed` to SGD and NoisyGD.
RandomSeed, NumRun = 1, 1
# eps : passed as `eps` to SGD/GD/NoisyGD -- presumably the number of
#       optimizer iterations; TODO confirm against those modules.
# bs  : passed as `bs` to SGD only -- presumably the mini-batch size; TODO confirm.
eps, bs = 200, 4
# lr, decay_rate : passed to all three optimizers -- presumably the initial
#                  step size and its decay factor; TODO confirm.
lr, decay_rate = 0.2, 0.99
def _record_run(slot, solution, hessian, acc_cov, trajectory, population, training_set,
                products, test_losses, gaps, trajectories):
    """Fold the outcome of one optimizer run into the running accumulators.

    Entry `slot` of the tensors `products`, `test_losses` and `gaps` is
    updated in place with, respectively:
      * trace(hessian @ acc_cov),
      * Evaluate(population, solution)   (the test loss),
      * Evaluate(population, solution) - Evaluate(training_set, solution)
        (the generalization gap).
    `trajectory` is appended to the list `trajectories`.
    """
    products[slot] += torch.trace(hessian @ acc_cov)
    test_loss = Evaluate(population, solution)
    train_loss = Evaluate(training_set, solution)
    test_losses[slot] += test_loss
    gaps[slot] += test_loss - train_loss
    trajectories.append(trajectory)


if __name__ == '__main__':
    # Driver: for each resampled training set, run SGD, GD and NoisyGD from
    # every initial point in IniSet (NumRun times each), then report the
    # averaged generalization gap / test loss and save trajectory heat maps.
    starting_time = time.time()
    random.seed(RandomSeed)
    torch.manual_seed(RandomSeed)

    # One slot per resampled training set, for each optimizer.
    GDGeneralizationGap = torch.zeros(NumTrainingSet)
    SGDGeneralizationGap = torch.zeros(NumTrainingSet)
    NoisyGDGeneralizationGap = torch.zeros(NumTrainingSet)
    GDProduct = torch.zeros(NumTrainingSet)
    SGDProduct = torch.zeros(NumTrainingSet)
    NoisyGDProduct = torch.zeros(NumTrainingSet)
    GDTestLosses = torch.zeros(NumTrainingSet)
    SGDTestLosses = torch.zeros(NumTrainingSet)
    NoisyGDTestLosses = torch.zeros(NumTrainingSet)
    # NOTE(review): nothing in the active code updates Moment (the experiment
    # that accumulated into it is disabled), so the final "average moment"
    # line always reports 0.
    Moment = 0

    # Construct the IniSet: every (x, y) pair on a regular grid over [0, 8.01).
    IniSet_x = torch.arange(0.0, 8.01, 0.888, requires_grad=False)
    IniSet_y = torch.arange(0.0, 8.01, 0.888, requires_grad=False)
    IniSet = torch.cartesian_prod(IniSet_x, IniSet_y)

    # Construct the set of loss functions; the last entry of CritPoints is
    # passed as the true critical point, all others as spurious ones.
    LossFunctions = ConstructLossFunctions(NumFunc=NumFunc, SpuriousCritPoints=CritPoints[:-1],
                                           TrueCritPoint=CritPoints[-1])

    os.makedirs('Plots', exist_ok=True)
    FuncHM(funcaverage(LossFunctions), 'PopulationLossHeatMap', 'PopulationLossData.csv')

    # Each accumulator slot receives one contribution per (initial point, run)
    # pair, so that product -- not len(IniSet) alone -- is the right divisor.
    runs_per_set = len(IniSet) * NumRun

    for i in range(NumTrainingSet):
        # Sample len(LossFunctions) 1-based function IDs with replacement.
        TrainingSetIDs = random.choices(range(1, len(LossFunctions)+1), k=len(LossFunctions))
        print(TrainingSetIDs)
        TrainingSet = [LossFunctions[func_id - 1] for func_id in TrainingSetIDs]

        # All artefacts of this training set go to Plots/<list of IDs>/.
        tag = str(TrainingSetIDs)
        out_dir = 'Plots/{}'.format(tag)
        os.makedirs(out_dir, exist_ok=True)

        FuncHM(funcaverage(TrainingSet), '{}/FuncHeatMap{}'.format(out_dir, tag),
               '{}/FuncData{}.csv'.format(out_dir, tag))
        GradNormHM(funcaverage(TrainingSet), '{}/GradNormHeatMap{}'.format(out_dir, tag),
                   '{}/GradNormData{}.csv'.format(out_dir, tag))
        GradVarHM(TrainingSet, '{}/GVHeatMap{}'.format(out_dir, tag),
                  '{}/GVData{}.csv'.format(out_dir, tag))

        GDTrajList = []
        SGDTrajList = []
        NoisyGDTrajList = []
        for init in IniSet:
            # NOTE(review): the same tensor x is handed to all three optimizers
            # and to every run; this assumes they do not modify `input` in
            # place -- TODO confirm against SGD/GD/NoisyGD.
            x = init.clone().detach().requires_grad_(True)
            for j in range(NumRun):
                # The second return value (positive ratio) is not used here.
                SGDSolution, _, SGDTraj, SGDHessian, SGDAccCov = SGD(
                    input=x, LossFunctions=TrainingSet, eps=eps, lr=lr,
                    decay_rate=decay_rate, bs=bs, seed=j)
                _record_run(i, SGDSolution, SGDHessian, SGDAccCov, SGDTraj,
                            LossFunctions, TrainingSet,
                            SGDProduct, SGDTestLosses, SGDGeneralizationGap, SGDTrajList)

                GDSolution, _, GDTraj, GDHessian, GDAccCov = GD(
                    input=x, LossFunctions=TrainingSet, eps=eps, lr=lr,
                    decay_rate=decay_rate)
                _record_run(i, GDSolution, GDHessian, GDAccCov, GDTraj,
                            LossFunctions, TrainingSet,
                            GDProduct, GDTestLosses, GDGeneralizationGap, GDTrajList)

                NoisyGDSolution, _, NoisyGDTraj, NoisyGDHessian, NoisyGDAccCov = NoisyGD(
                    input=x, LossFunctions=TrainingSet, eps=eps, lr=lr,
                    decay_rate=decay_rate, seed=j)
                _record_run(i, NoisyGDSolution, NoisyGDHessian, NoisyGDAccCov, NoisyGDTraj,
                            LossFunctions, TrainingSet,
                            NoisyGDProduct, NoisyGDTestLosses, NoisyGDGeneralizationGap,
                            NoisyGDTrajList)

        # Turn the sums into averages over all (initial point, run) pairs.
        for accumulator in (GDGeneralizationGap, SGDGeneralizationGap, NoisyGDGeneralizationGap,
                            GDProduct, SGDProduct, NoisyGDProduct,
                            GDTestLosses, SGDTestLosses, NoisyGDTestLosses):
            accumulator[i] /= runs_per_set

        print("The Generalization Gap of GD is {}".format(GDGeneralizationGap[i]))
        print("The Generalization Gap of SGD is {}".format(SGDGeneralizationGap[i]))
        print("The Generalization Gap of NoisyGD is {}".format(NoisyGDGeneralizationGap[i]))
        print("The Test Loss of GD is {}".format(GDTestLosses[i]))
        print("The Test Loss of SGD is {}".format(SGDTestLosses[i]))
        print("The Test Loss of NoisyGD is {}".format(NoisyGDTestLosses[i]))

        # Concatenate every trajectory of this training set along dim 1.
        # (The original author noted a GD shape of (2, 20100).)
        GDTrajHMData = torch.cat(GDTrajList, dim=1)
        SGDTrajHMData = torch.cat(SGDTrajList, dim=1)
        NoisyGDTrajHMData = torch.cat(NoisyGDTrajList, dim=1)
        print('GDTrajHMDataShape', GDTrajHMData.size())
        print('SGDTrajHMDataShape', SGDTrajHMData.size())
        print('NoisyGDTrajHMDataShape', NoisyGDTrajHMData.size())

        # Dump the raw points (one row per point) and draw the heat map.
        for label, traj_data in (('GD', GDTrajHMData),
                                 ('SGD', SGDTrajHMData),
                                 ('NoisyGD', NoisyGDTrajHMData)):
            np.savetxt('{}/{}TrajHMData.csv'.format(out_dir, label), traj_data.T.numpy(),
                       delimiter=',')
            TrajHM(traj_data[0], traj_data[1], '{}/{}TrajHeatMap{}'.format(out_dir, label, tag))

    print("The average generalization gap of GD is {}".format(GDGeneralizationGap.mean()))
    print("The average generalization gap of SGD is {}".format(SGDGeneralizationGap.mean()))
    print("The average generalization gap of NoisyGD is {}".format(NoisyGDGeneralizationGap.mean()))
    print("The average test loss of GD is {}".format(GDTestLosses.mean()))
    print("The average test loss of SGD is {}".format(SGDTestLosses.mean()))
    print("The average test loss of NoisyGD is {}".format(NoisyGDTestLosses.mean()))
    # Always 0 at present: Moment is never accumulated by the active code.
    print("The average moment of SGD is {}".format(Moment / (NumTrainingSet*NumRun)))

    print('Runtime is {}'.format(time.time() - starting_time))
